from utils.conn import get_db


class Model:
    """
    Minimal table-gateway base class on top of ``get_db()`` connections.

    Subclasses describe one table (``table``, ``primary_key``) and may override
    the ``_validate`` / ``_before_*`` / ``_after_*`` hooks. SQL strings passed to
    the query helpers may contain the ``$table`` token, which is replaced by
    ``self.table`` before execution. Placeholders use the ``?`` style.
    """

    # Unique id of the model. Defaults to the lowercased class name and may be
    # overridden in a subclass. It is used by access statistics, so do not
    # change it once chosen.
    model_id = ''
    # Database table name. Defaults to the lowercased class name when left
    # empty; may be overridden in a subclass.
    table = ''
    # Primary key column of the table, "id" by default; may be overridden.
    primary_key = 'id'
    # cursor.lastrowid captured by the most recent run_sql() call. For
    # statements other than INSERT its value is driver-dependent.
    last_id = 0

    # Client request POST data for the current insert/update.
    _form = None
    # Client query conditions for the current operation.
    _where = None
    # WHERE clause with placeholders, built by __build_where().
    _condition = ''
    # Values that fill the placeholders in _condition.
    _params = []
    # Result data of the last get_one()/get_list() query.
    _data = None

    # Operators accepted in dict/tuple where conditions. The operator is
    # written into the SQL text, and _where may come from the client, so only
    # this fixed set of single-placeholder operators is allowed.
    _allowed_operators = frozenset({
        '=', '==', '!=', '<>', '<', '<=', '>', '>=',
        'is', 'is not', 'is distinct from', 'is not distinct from',
        'like', 'not like', 'glob', 'not glob',
        'regexp', 'not regexp', 'match', 'not match',
    })

    def __init__(self):
        # Apply the documented defaults only when the subclass did not set them.
        if not self.model_id:
            self.model_id = self.__class__.__name__.lower()
        if not self.table:
            self.table = self.__class__.__name__.lower()
        # Per-instance query state, so nothing is shared through class attributes.
        self._condition = ''
        self._params = []

    def get_table_name(self):
        """Return the table this model operates on."""
        return self.table

    def _validate(self):
        """Hook: validate ``self._form``; called by do_insert/do_update. No-op by default."""
        pass

    def _after_get_one(self):
        """Hook: called by get_one after ``self._data`` is loaded. No-op by default."""
        pass

    def _after_get_list(self):
        """Hook: called by get_list after ``self._data`` is loaded. No-op by default."""
        pass

    def _before_do_insert(self):
        """Hook: called by do_insert before validation. No-op by default."""
        pass

    def _after_do_insert(self):
        """Hook: called by do_insert after the row is written. No-op by default."""
        pass

    def _before_do_update(self):
        """Hook: called by do_update before validation. No-op by default."""
        pass

    def _after_do_update(self):
        """Hook: called by do_update after the rows are written. No-op by default."""
        pass

    def _before_do_delete(self):
        """Hook: called by do_delete before the condition is built. No-op by default."""
        pass

    def _after_do_delete(self):
        """Hook: called by do_delete after the rows are deleted. No-op by default."""
        pass

    def __execute(self, sql, params, collect, commit=False):
        """
        Run one statement on a fresh connection and always release it.

        ``$table`` in ``sql`` is replaced by ``self.table``. When ``params`` is
        None the statement is executed without a parameter argument (required
        for PRAGMA, see get_column_names). ``collect`` receives the cursor after
        execution and its return value is returned. When ``commit`` is true the
        connection is committed after ``collect`` has run.
        """
        statement = sql.replace("$table", self.table)
        conn = get_db()
        try:
            cursor = conn.cursor()
            try:
                if params is None:
                    cursor.execute(statement)
                else:
                    cursor.execute(statement, params)
                result = collect(cursor)
                if commit:
                    conn.commit()
                return result
            finally:
                cursor.close()
        finally:
            conn.close()

    def query_one(self, sql, params=None):
        """
        Fetch a single row using a custom SQL statement.

        :param sql: SQL text; ``$table`` is replaced by the model's table name.
        :param params: placeholder values, or None for a statement without any.
        :return: the cursor's ``fetchone()`` result (None when no row matches).
        """
        return self.__execute(sql, params, lambda cursor: cursor.fetchone())

    def query_all(self, sql, params=None):
        """
        Fetch all rows using a custom SQL statement.

        :param sql: SQL text; ``$table`` is replaced by the model's table name.
        :param params: placeholder values, or None for a statement without any.
        :return: the cursor's ``fetchall()`` result.
        """
        return self.__execute(sql, params, lambda cursor: cursor.fetchall())

    def run_sql(self, sql, params):
        """
        Execute a modifying statement and commit it.

        :return: ``cursor.lastrowid`` of the statement, also stored in ``self.last_id``.
        """
        self.last_id = self.__execute(sql, params, lambda cursor: cursor.lastrowid, commit=True)
        return self.last_id

    def get_column_names(self):
        """
        Return the names of all columns of the current table.

        PRAGMA must be executed without any parameter argument, otherwise the
        driver fails with:
            ValueError: parameters are of unsupported type
        Reference: https://stackoverflow.com/questions/11996394
        """
        # Rows are indexed by column name; assumes get_db() sets a mapping-like
        # row factory (e.g. sqlite3.Row) — TODO confirm in utils.conn.
        rows = self.__execute("PRAGMA table_info($table)", None, lambda cursor: cursor.fetchall())
        return [row['name'] for row in rows]

    def get_one(self, where):
        """
        Fetch the first row matching ``where``.

        :param where: a primary-key value or a condition dict (see __build_where).
        :return: the row, or None when nothing matches.
        """
        self._where = where
        self.__build_where()

        sql = "select * from $table {} limit 1".format(self._condition)
        self._data = self.query_one(sql, self._params)
        self._after_get_one()

        return self._data

    @staticmethod
    def get_offset(page, page_size):
        """Return the row offset of 1-based ``page``; pages below 1 count as page 1."""
        page = max(int(page), 1)
        return (page - 1) * int(page_size)

    def get_list(self, page, page_size, where=None, order="id desc", fields=None):
        """
        Fetch one page of rows together with the total number of matching rows.

        :param page: 1-based page number.
        :param page_size: number of rows per page.
        :param where: optional condition (see __build_where).
        :param order: ORDER BY expression.
        :param fields: list/tuple of column expressions, or None for ``*``.
        :return: ``(total, rows)``.
        """
        self._where = where
        self.__build_where()

        offset = self.get_offset(page, page_size)

        # NOTE(review): ``order`` and ``fields`` are inserted into the SQL text
        # verbatim; they must never come from untrusted input.
        sql = "select {} from $table {} order by {} limit ?, ?" \
            .format(self.__build_fields(fields), self._condition, order)

        self._data = self.query_all(sql, self._params + [offset, int(page_size)])
        self._after_get_list()

        sql = "select count(*) as total from $table {}".format(self._condition)
        total = self.query_one(sql, self._params)['total']

        return total, self._data

    def do_delete(self, where):
        """
        Delete the rows matching ``where`` (a primary-key value or condition dict).

        :raises RuntimeWarning: when ``where`` is empty or every condition was
            skipped, so a whole-table delete is never issued.
        """
        if not where:
            raise RuntimeWarning('条件为空，忽略本次删除')

        self._where = where

        self._before_do_delete()
        self.__build_where()
        # All dict values may have been skipped as "": refuse to delete everything.
        if not self._condition:
            raise RuntimeWarning('条件为空，忽略本次删除')

        sql = "delete from $table {}".format(self._condition)
        self.run_sql(sql, self._params)

        self._after_do_delete()

    def do_update(self, form, where):
        """
        Update the rows matching ``where`` with the values of ``form``.

        Only keys of ``form`` that are columns of the table are written.

        :raises RuntimeWarning: when ``form``/``where`` is empty, no condition
            remains after building it, or no form key matches a column.
        """
        if not form or not where:
            raise RuntimeWarning('更新数据或更新条件为空，忽略本次更新')

        self._where = where
        self._form = form

        self._before_do_update()
        self._validate()
        self.__build_where()
        # All dict values may have been skipped as "": refuse to update everything.
        if not self._condition:
            raise RuntimeWarning('更新数据或更新条件为空，忽略本次更新')

        assignments = []
        params = []
        for column in self.get_column_names():
            if column in self._form:
                assignments.append('{} = ?'.format(self.__quote_identifier(column)))
                params.append(self._form[column])

        if not assignments:
            raise RuntimeWarning('更新数据或更新条件为空，忽略本次更新')

        sql = "update $table set {} {}".format(','.join(assignments), self._condition)
        self.run_sql(sql, params + self._params)

        self._after_do_update()

    def do_insert(self, form):
        """
        Insert one row built from the keys of ``form`` that are table columns.

        :return: the id reported by run_sql (``cursor.lastrowid``).
        :raises RuntimeWarning: when ``form`` is empty or matches no column.
        """
        if not form:
            raise RuntimeWarning('输入数据为空，忽略本次添加')

        self._form = form

        self._before_do_insert()
        self._validate()

        columns = []
        params = []
        for column in self.get_column_names():
            if column in self._form:
                columns.append(self.__quote_identifier(column))
                params.append(self._form[column])

        if not columns:
            raise RuntimeWarning('输入数据为空，忽略本次添加')

        placeholders = ','.join(['?'] * len(columns))

        sql = "insert into $table ({}) values ({})".format(','.join(columns), placeholders)
        self.run_sql(sql, params)
        insert_id = self.last_id

        self._after_do_insert()

        return insert_id

    def __build_where(self):
        """
        Translate ``self._where`` into ``self._condition`` and ``self._params``.

        Accepted forms of ``self._where``:
          * falsy value          -> no condition;
          * int/float/str/complex -> ``primary_key = value``;
          * dict mapping a column to ``value``, ``{operator: value}`` or
            ``(operator, value)``; entries whose value is "" are skipped.
        If no entry remains, the condition is empty (no WHERE clause).

        :raises TypeError: for unsupported value formats or operators.
        """
        # Reset first so a previous call's condition can never leak into this one.
        self._condition = ''
        self._params = []

        if not self._where:
            return

        conditions = []
        values = []

        if self.__is_basic_type(self._where):
            conditions.append('{} = ?'.format(self.__quote_identifier(self.primary_key)))
            values.append(self._where)
        elif isinstance(self._where, dict):
            for column, value in self._where.items():
                if self.__is_basic_type(value):
                    operator = "="
                elif isinstance(value, dict):
                    if not value:
                        raise TypeError("不支持的参数格式")
                    # Only the first {operator: value} pair is used.
                    operator, value = next(iter(value.items()))
                elif isinstance(value, tuple):
                    if len(value) < 2:
                        raise TypeError("不支持的参数格式")
                    operator, value = value[0], value[1]
                else:
                    raise TypeError("不支持的参数格式")

                # An empty value (e.g. a blank form field) means "no filter".
                if value == "":
                    continue

                conditions.append('{} {} ?'.format(
                    self.__quote_identifier(column), self.__normalize_operator(operator)))
                values.append(value)
        else:
            raise TypeError("不支持的参数格式")

        if conditions:
            self._condition = ' where ' + ' and '.join(conditions)
            self._params = values

    @classmethod
    def __normalize_operator(cls, operator):
        """Return ``operator`` lowercased with collapsed spaces, if it is whitelisted."""
        normalized = ' '.join(operator.split()).lower() if isinstance(operator, str) else None
        if normalized not in cls._allowed_operators:
            raise TypeError("不支持的操作符: {}".format(operator))
        return normalized

    @staticmethod
    def __quote_identifier(name):
        """Quote a column name with backticks, doubling any embedded backtick."""
        return '`{}`'.format(str(name).replace('`', '``'))

    @staticmethod
    def __is_basic_type(value):
        """Return True for scalar values usable directly as a placeholder value."""
        return isinstance(value, (int, float, str, complex))

    @staticmethod
    def __build_fields(fields=None):
        """Return the SELECT column list: ``*`` for None, else the joined fields."""
        if fields is None:
            return "*"
        elif isinstance(fields, (list, tuple)):
            return ",".join(fields)
        else:
            raise RuntimeError("Parameter `fields` is illegal!")
